These examples are adapted from the mixture section of the brms reference manual.
Here we simulate our data, dat.
library(tidyverse)
set.seed(1234)
dat <-
tibble(y = c(rnorm(200, mean = 0, sd = 1),
rnorm(100, mean = 6, sd = 1)),
x = rnorm(300, mean = 0, sd = 1),
z = sample(0:1, 300, replace = T))
head(dat)
Here’s what the data look like.
library(GGally)
theme_set(theme_grey() +
theme(panel.grid = element_blank()))
dat %>%
mutate(z = factor(z)) %>%
ggpairs()
fit1: A simple normal mixture model
Open brms.
library(brms)
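Before fitting, it may help to sketch the statistical model. In loosely brms-style notation (the subscripts are mine, just for exposition), our two-component Gaussian mixture is

\[
\begin{align*}
y_i & \sim \theta_1 \operatorname{Normal}(\mu_{1i}, \sigma_1) + \theta_2 \operatorname{Normal}(\mu_{2i}, \sigma_2) \\
\mu_{1i} & = \beta_{1, 0} + \beta_{1, 1} x_i + \beta_{1, 2} z_i \\
\mu_{2i} & = \beta_{2, 0} + \beta_{2, 1} x_i + \beta_{2, 2} z_i,
\end{align*}
\]

where the mixing proportions \(\theta_1\) and \(\theta_2\) are constrained to sum to 1.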
Fit the model.
fit1_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z),
prior = c(prior(normal(0, 7), Intercept, dpar = mu1),
prior(normal(5, 7), Intercept, dpar = mu2)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit1_s2 <-
update(fit1_s1,
seed = 2)
fit1_s3 <-
update(fit1_s1,
seed = 3)
fit1_s4 <-
update(fit1_s1,
seed = 4)
If you’d like to inspect all those chains, you can use the plot() function, as usual. Since we’re working in bulk, it might make sense to condense our diagnostics to \(\hat R\) plots via the bayesplot package.
library(bayesplot)
library(gridExtra)
p1 <-
rhat(fit1_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit1_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit1_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit1_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Recall we like our \(\hat R\) values to hover around 1. For the models from each seed, those are just a disaster. Let’s take a peek at the chains from just two of the fits to get a sense of the damage.
posterior_samples(fit1_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit1_s4, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 4") +
theme(legend.position = "top")
Whereas many of the chains in fit1_s1 wildly meandered across the parameter space, the parallel chains in fit1_s4 stabilized on alternative regions of it. I believe this is often called the label switching problem (e.g., see here). Either way, the resulting \(\hat R\) values were awful.
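To see why this happens, note that the mixture likelihood is unchanged if we swap which component we call 1 and which we call 2, so long as we swap the mixing weights too. A quick numerical sketch (the mix_density() function is mine, just for illustration; it is not part of brms):

```r
# The two-component Gaussian mixture density. Swapping the component labels
# (mu1/sigma1 with mu2/sigma2, and theta with 1 - theta) leaves it unchanged,
# which is why unconstrained chains may settle on either labeling.
mix_density <- function(y, theta, mu1, sigma1, mu2, sigma2) {
  theta * dnorm(y, mean = mu1, sd = sigma1) +
    (1 - theta) * dnorm(y, mean = mu2, sd = sigma2)
}

mix_density(2, theta = 2/3, mu1 = 0, sigma1 = 1, mu2 = 6, sigma2 = 1)
mix_density(2, theta = 1/3, mu1 = 6, sigma1 = 1, mu2 = 0, sigma2 = 1)
# the two calls return the same density
```

Since the data cannot distinguish the two labelings, only the priors (or an explicit ordering constraint) can.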
For our first attempt at fixing the issue, we might tighten up the priors. Of our three variables, two are standardized and the third is a dummy. It wouldn’t be unreasonable to put \(\sigma = 1\) Gaussians on all intercepts, \(\beta\)s, and even the model \(\sigma\)s themselves.
fit2_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z),
prior = c(prior(normal(0, 1), Intercept, dpar = mu1),
prior(normal(5, 1), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(normal(0, 1), class = sigma2)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit2_s2 <-
update(fit2_s1,
seed = 2)
fit2_s3 <-
update(fit2_s1,
seed = 3)
fit2_s4 <-
update(fit2_s1,
seed = 4)
Check the \(\hat R\) values.
p1 <-
rhat(fit2_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit2_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit2_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit2_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
They only look good for one of the four. Not very encouraging. Let’s revisit the chains for seed = 1 and also inspect the better-looking seed = 2.
posterior_samples(fit2_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit2_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
Well, the chains for seed = 1 aren’t wildly flailing across ridiculous areas of the parameter space anymore. But they show the same odd parallel behavior as those from seed = 4 in our first attempt. At least the chains from seed = 2 give us hope. If we were lazy, we’d just go ahead and use those. But that seems like a risky workflow to me. I’d like a more stable solution.
order = "mu"
The mixture() function has an order argument, which imposes an ordering constraint on the component means to discourage label switching. Let’s try it while keeping our original wide priors.
fit3_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian, order = "mu"),
bf(y ~ x + z),
prior = c(prior(normal(0, 7), Intercept, dpar = mu1),
prior(normal(5, 7), Intercept, dpar = mu2)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit3_s2 <-
update(fit3_s1,
seed = 2)
fit3_s3 <-
update(fit3_s1,
seed = 3)
fit3_s4 <-
update(fit3_s1,
seed = 4)
What do the \(\hat R\) values tell us?
p1 <-
rhat(fit3_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit3_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit3_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit3_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Nope, using order = "mu" didn’t solve the problem. Let’s confirm by looking at the chains.
posterior_samples(fit3_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit3_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
order = "mu" in addition to better priors
Here we combine order = "mu" with the tighter priors from the second attempt.
fit4_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian, order = "mu"),
bf(y ~ x + z),
prior = c(prior(normal(0, 1), Intercept, dpar = mu1),
prior(normal(5, 1), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(normal(0, 1), class = sigma2)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit4_s2 <-
update(fit4_s1,
seed = 2)
fit4_s3 <-
update(fit4_s1,
seed = 3)
fit4_s4 <-
update(fit4_s1,
seed = 4)
How do the \(\hat R\) values look now?
p1 <-
rhat(fit4_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit4_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit4_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit4_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Still failed for three of the four seeds. We need a better solution. Here are some of the chains.
posterior_samples(fit4_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit4_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
The label switching persists.
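One way to think about why tighter intercept priors help: the less the two intercept priors overlap, the harder it is for the components to trade places. As a rough illustration (the overlap() helper is mine, not part of brms), here’s the overlapping coefficient of the two intercept priors under the wide and tight settings:

```r
# Overlapping coefficient of the two intercept priors, Normal(0, sd) and
# Normal(5, sd): the area shared under the two density curves.
overlap <- function(sd) {
  integrate(function(x) pmin(dnorm(x, 0, sd), dnorm(x, 5, sd)),
            lower = -Inf, upper = Inf)$value
}

overlap(7)    # the wide priors share a large portion of their mass
overlap(0.5)  # the tight priors barely overlap at all
```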
Next we’ll tighten the intercept priors even further, reducing their Gaussian \(\sigma\)s to 0.5.
fit5_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian, order = "mu"),
bf(y ~ x + z),
prior = c(prior(normal(0, .5), Intercept, dpar = mu1),
prior(normal(5, .5), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(normal(0, 1), class = sigma2)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit5_s2 <-
update(fit5_s1,
seed = 2)
fit5_s3 <-
update(fit5_s1,
seed = 3)
fit5_s4 <-
update(fit5_s1,
seed = 4)
How do the \(\hat R\) values look now?
p1 <-
rhat(fit5_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit5_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit5_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit5_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Success! I’m feeling whipped from the previous versions, so let’s examine some of the chains to make sure it’s all good.
posterior_samples(fit5_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit5_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
Oh mama. Those are some sweet chains. So we learned a lesson: one reasonably reliable solution to the label switching problem is to hold the model’s hand with tight priors on the intercepts, or presumably on whichever other parameters we expect to differ substantially between the components. I’m not entirely happy with this method, though. It seems heavier-handed than I’d prefer.
Anyway, let’s look at the model summary.
print(fit5_s1)
## Family: mixture(gaussian, gaussian)
## Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity
## Formula: y ~ x + z
## Data: dat (Number of observations: 300)
## Samples: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup samples = 2000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## mu1_Intercept 0.05 0.11 -0.16 0.26 1.00 3231 1666
## mu2_Intercept 6.17 0.13 5.92 6.41 1.00 2973 1541
## mu1_x 0.05 0.07 -0.08 0.18 1.00 3735 1286
## mu1_z -0.17 0.15 -0.45 0.12 1.00 3746 1473
## mu2_x -0.08 0.10 -0.29 0.12 1.00 3753 1241
## mu2_z -0.11 0.18 -0.46 0.26 1.00 3728 1349
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma1 1.04 0.05 0.94 1.15 1.00 3767 1573
## sigma2 0.93 0.07 0.80 1.09 1.00 3460 1600
## theta1 0.67 0.03 0.61 0.72 1.00 4489 1473
## theta2 0.33 0.03 0.28 0.39 1.00 4489 1473
##
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample
## is a crude measure of effective sample size, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
The parameter estimates look great. And yeah, it makes sense this was a difficult model to fit. It was only the intercepts that varied across the two classes. Everything else was basically the same.
Well, okay, those \(\theta\) parameters differed. Which, by the way, invites the question of precisely what they are. They look a lot like mixing proportions. If so, the \(\theta\)s should always sum to 1. Let’s see.
posterior_samples(fit5_s1) %>%
transmute(theta_sum = theta1 + theta2) %>%
range()
## [1] 1 1
Yep, they always sum to 1, consistent with a proportion interpretation. Based on the combination of the intercepts and \(\theta\)s, the model is telling us the intercept was about 0 for 2/3 of the cases. We can confirm that’s correct with a quick look back at the simulation code:
dat <-
tibble(y = c(rnorm(200, mean = 0, sd = 1),
rnorm(100, mean = 6, sd = 1)),
x = rnorm(300, mean = 0, sd = 1),
z = sample(0:1, 300, replace = T))
Yep, for y, 200 of the total 300 cases were simulated based on the standard Gaussian.
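Those simulated counts imply mixing proportions of 2/3 and 1/3, right in line with the posterior means for \(\theta_1\) (0.67) and \(\theta_2\) (0.33), above.

```r
# The mixing proportions implied by the simulation code, for comparison
# with the posterior means of theta1 and theta2.
c(theta1 = 200 / 300, theta2 = 100 / 300)
```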
Let’s finish out the code from the reference manual and do a posterior predictive check.
pp_check(fit5_s1,
nsamples = 20)
Looks great!
Another angle on the label switching problem is to put an informative Dirichlet prior on the mixing proportions \(\theta\). Before fitting, let’s see what a Dirichlet(20, 10) distribution looks like.
rdirichlet(n = 1e4,
alpha = c(20, 10)) %>%
data.frame() %>%
gather() %>%
ggplot(aes(x = value, fill = key)) + geom_density(size = 0, alpha = 1/2) + xlim(0, 1)
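For reference, here are the prior mean and standard deviation of \(\theta_1\) implied by Dirichlet(20, 10), computed from the standard Dirichlet moment formulas. Note the mean, (2/3, 1/3), happens to match the true mixing proportions in our simulation.

```r
# Mean and standard deviation of theta1 under a Dirichlet(20, 10) prior,
# via the standard Dirichlet moment formulas.
alpha  <- c(20, 10)
alpha0 <- sum(alpha)

alpha / alpha0                                         # prior means: 2/3 and 1/3
sqrt(alpha[1] * alpha[2] / (alpha0^2 * (alpha0 + 1)))  # sd of theta1, about 0.085
```

Now fit the model with that prior on \(\theta\).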
fit6_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z),
prior = c(prior(normal(0, 7), Intercept, dpar = mu1),
prior(normal(5, 7), Intercept, dpar = mu2),
prior(dirichlet(20, 10), theta)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit6_s2 <-
update(fit6_s1,
seed = 2)
fit6_s3 <-
update(fit6_s1,
seed = 3)
fit6_s4 <-
update(fit6_s1,
seed = 4)
Check the rhats.
p1 <-
rhat(fit6_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit6_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit6_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit6_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Still only one of the four looks good. Let’s inspect some of the chains.
posterior_samples(fit6_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit6_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
Perhaps a stronger Dirichlet(200, 100) prior will help. Here’s its shape.
rdirichlet(n = 1e4,
alpha = c(200, 100)) %>%
data.frame() %>%
gather() %>%
ggplot(aes(x = value, fill = key)) + geom_density(size = 0, alpha = 1/2) + xlim(0, 1)
fit7_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z),
prior = c(prior(normal(0, 7), Intercept, dpar = mu1),
prior(normal(5, 7), Intercept, dpar = mu2),
prior(dirichlet(200, 100), theta)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit7_s2 <-
update(fit7_s1,
seed = 2)
fit7_s3 <-
update(fit7_s1,
seed = 3)
fit7_s4 <-
update(fit7_s1,
seed = 4)
Check the rhats.
p1 <-
rhat(fit7_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit7_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit7_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit7_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Now we combine the strong dirichlet(200, 100) prior with the tighter priors from our second attempt.
fit8_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z),
prior = c(prior(normal(0, 1), Intercept, dpar = mu1),
prior(normal(5, 1), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(normal(0, 1), class = sigma2),
prior(dirichlet(200, 100), theta)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit8_s2 <-
update(fit8_s1,
seed = 2)
fit8_s3 <-
update(fit8_s1,
seed = 3)
fit8_s4 <-
update(fit8_s1,
seed = 4)
Check the \(\hat R\) values.
p1 <-
rhat(fit8_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit8_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit8_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit8_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Still largely a disaster.
posterior_samples(fit8_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit8_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
Another way to further constrain the model is to force the two residual standard deviations to be equal, by setting sigma2 = 'sigma1' within bf().
fit9_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z,
sigma2 = 'sigma1'),
prior = c(prior(normal(0, 1), Intercept, dpar = mu1),
prior(normal(5, 1), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(dirichlet(200, 100), theta)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit9_s2 <-
update(fit9_s1,
seed = 2)
fit9_s3 <-
update(fit9_s1,
seed = 3)
fit9_s4 <-
update(fit9_s1,
seed = 4)
Check the \(\hat R\) values.
p1 <-
rhat(fit9_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit9_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit9_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit9_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Success!
posterior_samples(fit9_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit9_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
Does the success hold if we relax the mixing prior back to dirichlet(20, 10)?
fit10_s1 <-
brm(data = dat,
family = mixture(gaussian, gaussian),
bf(y ~ x + z,
sigma2 = 'sigma1'),
prior = c(prior(normal(0, 1), Intercept, dpar = mu1),
prior(normal(5, 1), Intercept, dpar = mu2),
prior(normal(0, 1), class = b, dpar = mu1),
prior(normal(0, 1), class = b, dpar = mu2),
prior(normal(0, 1), class = sigma1),
prior(dirichlet(20, 10), theta)),
iter = 2000, warmup = 1000, chains = 2, cores = 2,
seed = 1)
fit10_s2 <-
update(fit10_s1,
seed = 2)
fit10_s3 <-
update(fit10_s1,
seed = 3)
fit10_s4 <-
update(fit10_s1,
seed = 4)
Check the \(\hat R\) values.
p1 <-
rhat(fit10_s1) %>%
mcmc_rhat()
p2 <-
rhat(fit10_s2) %>%
mcmc_rhat()
p3 <-
rhat(fit10_s3) %>%
mcmc_rhat()
p4 <-
rhat(fit10_s4) %>%
mcmc_rhat()
grid.arrange(p1, p2, p3, p4, ncol = 2)
Still a success!
posterior_samples(fit10_s1, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 1") +
theme(legend.position = "top")
posterior_samples(fit10_s2, add_chain = T) %>%
select(-lp__, -iter) %>%
mcmc_trace(facet_args = list(ncol = 5)) +
ggtitle("seed = 2") +
theme(legend.position = "top")
sessionInfo()
## R version 3.6.0 (2019-04-26)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] gridExtra_2.3 bayesplot_1.7.0 brms_2.10.0 Rcpp_1.0.2
## [5] GGally_1.4.0 forcats_0.4.0 stringr_1.4.0 dplyr_0.8.1
## [9] purrr_0.3.2 readr_1.3.1 tidyr_0.8.3 tibble_2.1.3
## [13] ggplot2_3.2.1 tidyverse_1.2.1
##
## loaded via a namespace (and not attached):
## [1] nlme_3.1-139 matrixStats_0.54.0 xts_0.11-2
## [4] lubridate_1.7.4 RColorBrewer_1.1-2 threejs_0.3.1
## [7] httr_1.4.0 rstan_2.19.2 tools_3.6.0
## [10] backports_1.1.4 R6_2.4.0 DT_0.7
## [13] lazyeval_0.2.2 colorspace_1.4-1 withr_2.1.2
## [16] prettyunits_1.0.2 processx_3.3.1 tidyselect_0.2.5
## [19] Brobdingnag_1.2-6 compiler_3.6.0 cli_1.1.0
## [22] rvest_0.3.4 xml2_1.2.0 shinyjs_1.0
## [25] labeling_0.3 colourpicker_1.0 scales_1.0.0
## [28] dygraphs_1.1.1.6 mvtnorm_1.0-11 callr_3.2.0
## [31] ggridges_0.5.1 StanHeaders_2.18.1-10 digest_0.6.20
## [34] rmarkdown_1.13 base64enc_0.1-3 pkgconfig_2.0.2
## [37] htmltools_0.3.6 htmlwidgets_1.3 rlang_0.4.0
## [40] readxl_1.3.1 rstudioapi_0.10 shiny_1.3.2
## [43] generics_0.0.2 zoo_1.8-6 jsonlite_1.6
## [46] crosstalk_1.0.0 gtools_3.8.1 inline_0.3.15
## [49] magrittr_1.5 loo_2.1.0 Matrix_1.2-17
## [52] munsell_0.5.0 abind_1.4-5 stringi_1.4.3
## [55] yaml_2.2.0 pkgbuild_1.0.3 plyr_1.8.4
## [58] grid_3.6.0 parallel_3.6.0 promises_1.0.1
## [61] crayon_1.3.4 miniUI_0.1.1.1 lattice_0.20-38
## [64] haven_2.1.0 hms_0.4.2 ps_1.3.0
## [67] knitr_1.23 pillar_1.4.2 igraph_1.2.4.1
## [70] markdown_1.0 shinystan_2.5.0 codetools_0.2-16
## [73] stats4_3.6.0 reshape2_1.4.3 rstantools_1.5.1
## [76] glue_1.3.1 evaluate_0.14 modelr_0.1.4
## [79] httpuv_1.5.1 cellranger_1.1.0 gtable_0.3.0
## [82] reshape_0.8.8 assertthat_0.2.1 xfun_0.8
## [85] mime_0.7 xtable_1.8-4 broom_0.5.2
## [88] coda_0.19-2 later_0.8.0 rsconnect_0.8.13
## [91] shinythemes_1.1.2 bridgesampling_0.6-0